/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef FUSION_ENGINE_OPTIMIZER_ADAPTER_TBE_ADAPTER_TBE_OP_STORE_ADAPTER_H_
#define FUSION_ENGINE_OPTIMIZER_ADAPTER_TBE_ADAPTER_TBE_OP_STORE_ADAPTER_H_

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "adapter/adapter_itf/op_store_adapter.h"
#include "adapter/tbe_adapter/tbe_info/tbe_info_assembler.h"
#include "adapter/tbe_adapter/tbe_info/tbe_single_op_info_assembler.h"
#include "common/plugin_manager.h"
#include "common/scope_allocator.h"
#include "graph_optimizer/graph_optimize_register_error_codes.h"
#include "tensor_engine/fusion_api.h"

namespace fe {
class TbeOpStoreAdapter;
using TbeOpStoreAdapterPtr = std::shared_ptr<TbeOpStoreAdapter>;
using PluginManagerPtr = std::shared_ptr<fe::PluginManager>;
using TbeInfoAssemblerPtr = std::shared_ptr<fe::TbeInfoAssembler>;
using TbeSingleOpInfoAssemblerPtr = std::shared_ptr<fe::TbeSingleOpInfoAssembler>;
using TbeOpInfoPtr = std::shared_ptr<te::TbeOpInfo>;
using ScopeAllocatorPtr = std::shared_ptr<ScopeAllocator>;

class TbeOpStoreAdapter : public OpStoreAdapter {
 public:
  /* There are two versions of CompileOp, this one does not care about the
   * compile strategy. */
  Status CompileOp(ScopeNodeIdMap &fusion_nodes_map, std::map<int64_t, std::string> &json_path_map,
                   std::vector<ge::NodePtr> &buff_fus_compile_failed_nodes,
                   const std::vector<ge::NodePtr> &buff_fus_to_del_nodes) override;
  /*
   *  @ingroup fe
   *  @brief   compile fused op and single op, and generate .o and json files
   *  @param   [in]  fusion_nodes_map  op id and fused sub-graph
   *  @param   [out] json_file_map_    keep path of .o and json of each op
   *  @return  SUCCESS or FAILED
   */
  Status CompileOp(CompileInfoParam &compile_info) override;

  /*
   *  @ingroup fe
   *  @brief   pre-compile and return pattern of op
   *  @return  SUCCESS or FAILED
   */
  Status PreCompileOp(vector<PreCompileNodePara> &compile_para_vec) override;
  /*
   *  @ingroup fe
   *  @brief   initial resources needed by TbeCompilerAdapter, such as dlopen so
   * files
   *           and load function symbols etc.
   *  @return  SUCCESS or FAILED
   */
  Status Initialize(const std::map<std::string, std::string> &options, const std::string &engine_name) override;
  Status InitializeInner(const std::map<std::string, std::string> &options, const std::string &engine_name);
  Status InitializeInnerHelp();

  /*
   *  @ingroup fe
   *  @brief   finalize resources initialized in Initialize function,
   *           such as dclose so files etc.
   *  @return  SUCCESS or FAILED
   */
  Status Finalize() override;

  bool CheckSupport(const ge::OpDesc &op_desc, OpKernelInfoPtr op_kernel_info_ptr, std::string &reason) override;

  Status SelectOpFormat(const ge::OpDesc &op_desc, const OpKernelInfoPtr &op_kernel_info_ptr,
                        const HeavyFormatInfo &heavy_format_info, string &op_format_dtype_str) override;

  Status OpBuilder(ge::NodePtr node_ptr) override;

  Status UpdateTensorByMixPrecisionMode(ge::OpDescPtr &op_desc, OpKernelInfoPtr &op_kernel_info_ptr);

 private:
  PluginManagerPtr plugin_manager_ptr{nullptr};

  // function wrt TBE API
  function<bool(const te::TbeOpInfo &, std::string &)> SelectTbeOpFormat{nullptr};
  function<bool(te::TbeOpInfo &, te::CheckSupportedResult &, string &reason)> CheckTbeSupported{nullptr};
  function<bool(te::TbeOpInfo &, uint64_t, uint64_t)> PreBuildTbeOp{nullptr};
  function<te::OpBuildResCode(std::vector<ge::Node *>, ge::OpDescPtr, const std::vector<ge::NodePtr> &, uint64_t,
                              uint64_t, const std::string &)> TeFusion{nullptr};
  function<te::OpBuildResCode(uint64_t, uint64_t, ge::Node &)> FuzzBuildTbeOp{nullptr};
  function<te::LX_QUERY_STATUS(const te::TbeOpInfo &, std::string &)> GetOpInfo{nullptr};
  function<bool(const std::map<std::string, std::string> &, bool *)> TbeInitialize{nullptr};
  function<bool()> TbeFinalize{nullptr};
  function<bool(uint64_t, vector<te::FinComTask> &)> WaitAllFinished{nullptr};

  struct CompileTaskPara {
    uint64_t task_num;
    std::unordered_map<uint64_t, int64_t> task_scope_id;
    vector<te::FinComTask> failed_tasks;
    vector<te::FinComTask> succ_tasks;
    map<int64_t, std::string> *json_path_map;
    ScopeNodeIdMap *fusion_nodes_map;
    std::unordered_map<uint64_t, ge::Node *> task_node_map;
    std::unordered_map<uint64_t, TbeOpInfoPtr> task_tbe_info_map;
  };
  std::string engine_name_;
  bool init_flag{false};
  bool support_parallel_compile{false};
  bool ConvertCheckSupportResult(const ge::OpDesc &op_desc, te::CheckSupportedResult &is_supported);

  Status SetOpJsonPath(ge::OpDescPtr &compile_op_desc, map<int64_t, std::string> &json_path_map, int scope_idx);

  Status ParallelCompileOp(ScopeNodeIdMap &fusion_nodes_map, map<int64_t, std::string> &json_path_map,
                           std::vector<ge::NodePtr> &buff_fus_compile_failed_nodes,
                           const std::vector<ge::NodePtr> &buff_fus_to_del_nodes,
                           bool ignore_compile_strategy = false, int64_t scope_id_minimum = 0);

  Status WaitTaskFinish(CompileTaskPara &task_para);

  Status PreCompileOp(ge::Node *node, OpKernelInfoPtr op_kernel_info_ptr, const std::string &imply_type_str,
                      const std::string &op_dsl_file_path, const std::string &session_graph_id);

  Status ProcessSuccCompileTask(CompileTaskPara &task_para);

  Status ProcessFailCompileTask(CompileTaskPara &task_para, std::vector<ge::NodePtr> &buff_fus_compile_failed_nodes,
                                bool ignore_compile_strategy = false, int64_t scope_id_minimum = 0);

  Status ProcessFailedCompileTask(CompileTaskPara &task_para, std::vector<ge::NodePtr> &buff_fus_compile_failed_nodes,
                                  bool ignore_compile_strategy, int64_t scope_id_minimum);

  Status ProcessSuccPreCompTask(CompileTaskPara &task_para);

  Status ProcessFailPreCompTask(CompileTaskPara &task_para);

  Status DoFuzzBuildTbeOp(std::vector<ge::Node *> &node_vec, uint64_t taskId, uint64_t thread_id);

  Status SetTeTask(vector<ge::Node *> &node_vec, CompileTaskPara &task_para, uint64_t taskId,
                   const std::vector<ge::NodePtr> &buff_fus_to_del_nodes, bool ignore_compile_strategy = false);

  Status GetTbeOpStoreInfo(const ge::OpDesc &op_desc, const OpKernelInfoPtr &op_kernel_info_ptr,
                           FEOpsStoreInfo &op_store_info);

  TbeInfoAssemblerPtr tbe_info_assembler_ptr_;

  TbeSingleOpInfoAssemblerPtr tbe_single_op_info_assembler_ptr_;

  Status SetPreCompilePattern(ge::OpDescPtr op_desc, te::TbeOpInfo &op_info, string &op_pattern_before_buff_fus);

  TbeOpInfoPtr PreCompSetTbeOpInfo(PreCompileNodePara &compile_para);

  Status ParallelPreCompileOp(vector<PreCompileNodePara> &compile_para_vec);

  Status SerialPreCompileOp(vector<PreCompileNodePara> &compile_para_vec);

  void ChangeBufferOptimize(const std::map<std::string, std::string> &options,
                            std::map<std::string, std::string> &new_options);

  Status SetOpCompileInfo(std::vector<ge::Node *> &nodes, const ge::OpDescPtr &op_desc_ptr);

  Status SetSupportDynamicShape(std::vector<ge::Node *> &nodes);

  bool StopCompileOpInTuningAndAfterUBMatchMode();

  bool StopWaitTaskFinishInTuningAndAfterBuilderMode(bool ignore_compile_strategy);

  void SetFusionFailedId(const vector<ge::Node *> &fusion_nodes, const int64_t &fusion_failed_id);

  void SetCustomFlag(ScopeNodeIdMap &fusion_nodes_map);

  ScopeAllocatorPtr scope_allocator_ptr_;

  // initialize required tbe api for tbe adapter
  Status InitTbeFunctions(PluginManagerPtr &plugin_manager_ptr);

  Status FillInTaskParam(ScopeNodeIdMap &fusion_nodes_map, map<int64_t, std::string> &json_path_map,
                         const std::vector<ge::NodePtr> &buff_fus_to_del_nodes, CompileTaskPara &task_para,
                         bool ignore_compile_strategy);

  void RollBackAttributes(std::vector<ge::Node *> &failed_nodes);
};
}  // namespace fe

#endif  // FUSION_ENGINE_OPTIMIZER_ADAPTER_TBE_ADAPTER_TBE_OP_STORE_ADAPTER_H_
